import itertools
import os
import pickle

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


# Set the default font size to 12 for every figure created after this point.
mpl.rcParams.update({'font.size': 12})
# seaborn "tab10" colour palette; the plotting functions below index into it
# to pick one colour per curve.
palette = sns.color_palette("tab10")


def ema(scalars, weight=0.5):
    """Smooth a sequence with an exponential moving average.

    The running value is seeded with the first element, and every output is
    ``previous * weight + (1 - weight) * point``; a larger ``weight`` keeps
    more of the history and therefore smooths more.

    Args:
        scalars: Sized, indexable sequence of values to smooth.
        weight: Fraction of the previous smoothed value that is kept.

    Returns:
        A NumPy array holding one smoothed value per input element, or an
        empty array when ``scalars`` is empty.
    """
    # Seeding from scalars[0] would raise IndexError on an empty sequence.
    if len(scalars) == 0:
        return np.array([])

    last = scalars[0]
    smoothed = []
    for point in scalars:
        last = last * weight + (1 - weight) * point
        smoothed.append(last)

    return np.array(smoothed)


def smooth(scalars, weight=0.1, start=0):
    """Smooth ``scalars`` from index ``start`` on, keeping the prefix as is.

    Args:
        scalars: Sliceable sequence of values.
        weight: Smoothing weight forwarded to ``ema``.
        start: Index of the first element to smooth; elements before it are
            copied unchanged.

    Returns:
        A NumPy array made of ``scalars[:start]`` followed by the
        exponentially smoothed ``scalars[start:]``.
    """
    head, tail = scalars[:start], scalars[start:]

    # Nothing to smooth (empty input or start past the end): return the
    # prefix directly so that ``ema`` is never asked to seed itself from an
    # empty sequence.
    if len(tail) == 0:
        return np.asarray(head, dtype=float)

    return np.concatenate((head, ema(tail, weight)))


def smooth_upper(scalars, variance, weight=0.1, start=0):
    """Return the smoothed upper envelope ``scalars + variance``.

    Args:
        scalars: Centre values of the curve.
        variance: Offset added to ``scalars`` before smoothing.
        weight: Smoothing weight forwarded to ``smooth``.
        start: Index of the first smoothed element, forwarded to ``smooth``.

    Returns:
        Whatever ``smooth`` returns for the shifted curve.
    """
    upper_envelope = scalars + variance
    return smooth(upper_envelope, weight=weight, start=start)


def smooth_lower(scalars, variance, min_val, weight=0.1, start=0):
    """Return the smoothed lower envelope, floored at ``min_val``.

    Args:
        scalars: Centre values of the curve.
        variance: Offset subtracted from ``scalars`` before smoothing.
        min_val: Element-wise floor applied to ``scalars - variance``.
        weight: Smoothing weight forwarded to ``smooth``.
        start: Index of the first smoothed element, forwarded to ``smooth``.

    Returns:
        Whatever ``smooth`` returns for the floored, shifted curve.
    """
    lower_envelope = scalars - variance
    floored = np.maximum(lower_envelope, min_val)
    return smooth(floored, weight=weight, start=start)


def plot_speed_up(results_dir="../results_time_series",
                  plot_dir="../figures/speed_up"):
    """Plot, per algorithm, how a metric evolves for several client counts.

    For every hyper-parameter combination of the grid defined below, the
    pickled results of each client count are loaded and one PDF is written
    per (metric, algorithm) pair. Each figure has a log-scaled y axis and one
    curve per client count: the mean over runs against the communication
    round, with a band of ``mean +/- 1.96 * std / sqrt(n_runs)``.

    Args:
        results_dir: Directory containing the ``*.pkl`` result files.
        plot_dir: Root directory under which the figures are written; it is
            created if missing.

    Raises:
        FileNotFoundError: If a result file of the grid does not exist.
    """
    os.makedirs(plot_dir, exist_ok=True)

    n_months_list = [6, 12]
    global_lr_list = [0.1]
    local_lr_list = [0.01, 0.001]
    momentum_coef_list = [0.5]
    local_steps_list = [100]
    n_clients_list = [6, 24, 240]

    algos = ["local_sgd", "minibatch_sgd", "local_sgd_momentum"]
    metrics = ["loss", "grad_norm"]
    # The offset of 3 skips the first three palette entries, which `plot`
    # assigns to the three algorithms.
    colors = {
        n_clients: palette[i + 3]
        for i, n_clients in enumerate(n_clients_list)
    }
    labels = {
        n_clients: f"{n_clients} clients"
        for n_clients in n_clients_list
    }

    grid = itertools.product(
        n_months_list,
        global_lr_list,
        local_lr_list,
        momentum_coef_list,
        local_steps_list,
    )
    for n_months, global_lr, local_lr, momentum_coef, local_steps in grid:
        # values[metric][algo] holds one entry per client count, in the order
        # of n_clients_list.
        values = {
            metric: {algo: [] for algo in algos}
            for metric in metrics
        }

        for n_clients in n_clients_list:
            file_name = f"local_lr={local_lr},global_lr={global_lr},momentum={momentum_coef},local_steps={local_steps},n_clients={n_clients},n_months={n_months},regularization=True"
            # NOTE: unpickling can execute arbitrary code; only point
            # results_dir at files produced by trusted runs.
            with open(f"{results_dir}/{file_name}.pkl", "rb") as f:
                results = pickle.load(f)

            for metric, algo in itertools.product(metrics, algos):
                values[metric][algo].append(results[algo][metric])

        for metric, algo in itertools.product(metrics, algos):
            fig, ax = plt.subplots(1, 1, figsize=(4, 3))
            ax.set_yscale("log")
            ax.set_xlabel("Communication rounds")

            for n_clients, data in zip(n_clients_list, values[metric][algo]):
                # data: shape n_runs (10) x n_communications
                mean = data.mean(axis=0)
                ci = 1.96 * data.std(axis=0) / np.sqrt(data.shape[0])
                # The lower edge is clamped to the per-round minimum over
                # runs -- presumably so the band never drops below observed
                # values on the log axis.
                lower = np.maximum(mean - ci, data.min(axis=0))
                upper = mean + ci

                # Draw at most 200 evenly spaced rounds.
                plt_idx = np.linspace(
                    start=0,
                    stop=data.shape[1] - 1,
                    num=min(data.shape[1], 200),
                    dtype=int,
                )
                ax.plot(
                    plt_idx,
                    mean[plt_idx],
                    color=colors[n_clients],
                    label=labels[n_clients],
                )
                ax.fill_between(
                    plt_idx,
                    lower[plt_idx],
                    upper[plt_idx],
                    color=colors[n_clients],
                    edgecolor=colors[n_clients],
                    alpha=0.5,
                )

            # Only the momentum figure carries the legend; build it once,
            # after every curve has been added.
            if algo == "local_sgd_momentum":
                ax.legend(loc="upper right")

            plot_subdir = f"{plot_dir}/n_months={n_months}/global_lr={global_lr}/local_lr={local_lr}/momentum_coef={momentum_coef}/regularization=True/metric={metric}/K={local_steps}"
            os.makedirs(plot_subdir, exist_ok=True)
            fig.savefig(os.path.join(plot_subdir, f"{algo}.pdf"),
                        bbox_inches="tight")
            plt.close(fig)


def plot(results_dir="../results_time_series",
         plot_dir="../figures/heterogeneity"):
    """Plot, per configuration, one metric curve for each algorithm.

    For every hyper-parameter combination of the grid defined below, the
    matching pickled result file is loaded once and one PDF is written per
    metric. Each figure has a log-scaled y axis and one curve per algorithm:
    the mean over runs against the communication round, with a band of
    ``mean +/- 1.96 * std / sqrt(n_runs)``.

    Args:
        results_dir: Directory containing the ``*.pkl`` result files.
        plot_dir: Root directory under which the figures are written; it is
            created if missing.

    Raises:
        FileNotFoundError: If a result file of the grid does not exist.
    """
    os.makedirs(plot_dir, exist_ok=True)

    n_months_list = [6, 12]
    global_lr_list = [0.1]
    local_lr_list = [0.01, 0.001]
    momentum_coef_list = [0.5]
    local_steps_list = [10, 50, 100]
    n_clients_list = [6, 24, 240]
    regularization_list = [True]

    algos = ["local_sgd", "minibatch_sgd", "local_sgd_momentum"]
    metrics = ["loss", "grad_norm"]
    colors = {algo: palette[i] for i, algo in enumerate(algos)}
    labels = {
        "local_sgd": "Local SGD",
        "minibatch_sgd": "Minibatch SGD",
        "local_sgd_momentum": "Local SGD Momentum",
    }

    grid = itertools.product(
        n_months_list,
        n_clients_list,
        global_lr_list,
        local_lr_list,
        momentum_coef_list,
        regularization_list,
        local_steps_list,
    )
    for (n_months, n_clients, global_lr, local_lr,
         momentum_coef, regularization, local_steps) in grid:
        file_name = f"local_lr={local_lr},global_lr={global_lr},momentum={momentum_coef},local_steps={local_steps},n_clients={n_clients},n_months={n_months},regularization={regularization}"
        # The file does not depend on the metric, so it is read once per
        # configuration and before any figure is created.
        # NOTE: unpickling can execute arbitrary code; only point
        # results_dir at files produced by trusted runs.
        with open(os.path.join(results_dir, f"{file_name}.pkl"), "rb") as f:
            results = pickle.load(f)

        for metric_name in metrics:
            fig, ax = plt.subplots(1, 1, figsize=(4, 3))
            ax.set_yscale("log", base=10)
            ax.set_xlabel("Communication rounds")

            for algo in algos:
                # data: shape n_runs (10) x n_communications
                data = results[algo][metric_name]
                mean = data.mean(axis=0)
                ci = 1.96 * data.std(axis=0) / np.sqrt(data.shape[0])
                # The lower edge is clamped to the per-round minimum over
                # runs -- presumably so the band never drops below observed
                # values on the log axis.
                lower = np.maximum(mean - ci, data.min(axis=0))
                upper = mean + ci

                # Draw at most 200 evenly spaced rounds.
                plt_idx = np.linspace(
                    start=0,
                    stop=data.shape[1] - 1,
                    num=min(data.shape[1], 200),
                    dtype=int,
                )
                ax.plot(
                    plt_idx,
                    mean[plt_idx],
                    color=colors[algo],
                    label=labels[algo],
                )
                ax.fill_between(
                    plt_idx,
                    lower[plt_idx],
                    upper[plt_idx],
                    color=colors[algo],
                    edgecolor=colors[algo],
                    alpha=0.2,
                )

            # Only the K=100 figures carry the legend.
            if local_steps == 100:
                ax.legend(loc="upper right", fontsize="small")

            plot_subdir = f"{plot_dir}/n_months={n_months}/n_clients={n_clients}/global_lr={global_lr}/local_lr={local_lr}/momentum_coef={momentum_coef}/regularization={regularization}/metric={metric_name}"
            os.makedirs(plot_subdir, exist_ok=True)
            fig.savefig(os.path.join(plot_subdir, f"K={local_steps}.pdf"),
                        bbox_inches="tight")
            plt.close(fig)


# When executed as a script, generate the per-configuration algorithm
# comparison figures first, then the client-count (speed-up) figures.
if __name__ == "__main__":
    plot()
    plot_speed_up()